import random
from operator import itemgetter

import numpy as np
import torch


def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: integer seed passed to ``random``, ``numpy.random`` and torch
            (CPU and every CUDA device).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe without a GPU: PyTorch silently ignores these when CUDA is absent.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Spelling matters here: a misspelled attribute name is silently accepted
    # by the module and would leave cuDNN non-deterministic.
    torch.backends.cudnn.deterministic = True


def process_idx(loader):
    """Attach batch-level clique index tensors to every batch of ``loader``.

    For each batch, every graph's clique rows (``clique_idx``, truncated to the
    matching ``clique_slice`` length) are shifted by the total ``x.size(0)`` of
    the graphs before it in that batch and collected into
    ``batch.clique_batch``. Presumably this makes the indices address the
    batch's concatenated node features — confirm against the model code.

    Args:
        loader: iterable of batch objects that support ``len``, integer
            indexing to per-graph objects (with ``x``, ``clique_idx`` and
            ``clique_slice``) and attribute assignment.

    Returns:
        list of the same batch objects, each with ``clique_batch`` set.
    """
    new_loader = []
    for batch in loader:
        clique_batch = []
        offset = 0  # number of nodes in the graphs already visited
        for t in range(len(batch)):
            graph = batch[t]  # index once: Batch.__getitem__ rebuilds the graph
            for clique, length in zip(graph.clique_idx, graph.clique_slice):
                # Out-of-place add: ``add_`` on a slice would shift the
                # caller's clique_idx tensors and corrupt them on reprocessing.
                clique_batch.append(clique[:length] + offset)
            offset += graph.x.size(0)
        batch.clique_batch = clique_batch
        new_loader.append(batch)
    return new_loader


def add_Inbatch(batch):
    """Attach a batch-level clique index list to a single batch.

    Every graph's clique rows (``clique_idx``, truncated to the matching
    ``clique_slice`` length) are shifted by the total ``x.size(0)`` of the
    graphs before it and collected into ``batch.clique_batch``. Presumably the
    result indexes the batch's concatenated node features — confirm.

    Args:
        batch: object supporting ``len``, integer indexing to per-graph objects
            (with ``x``, ``clique_idx`` and ``clique_slice``) and attribute
            assignment.

    Returns:
        the same ``batch`` object with ``clique_batch`` set.
    """
    clique_batch = []
    offset = 0  # number of nodes in the graphs already visited
    for t in range(len(batch)):
        graph = batch[t]  # index once: Batch.__getitem__ rebuilds the graph
        for clique, length in zip(graph.clique_idx, graph.clique_slice):
            # Out-of-place add: ``add_`` on a slice would shift the caller's
            # clique_idx tensors, so calling this twice would double-shift.
            clique_batch.append(clique[:length] + offset)
        offset += graph.x.size(0)
    batch.clique_batch = clique_batch
    return batch


def get_dict(graphs, clique_dict, all_edges):
    """Gather the clique list and clique-edge list for each graph.

    Args:
        graphs: iterable of objects whose ``id`` attribute is used as the
            key/index into both lookups.
        clique_dict: per-graph clique lists, indexable by ``id``.
        all_edges: per-graph clique-edge lists, indexable by ``id``.

    Returns:
        ``(cli_dict, edge_dict)``: two tuples aligned position-by-position with
        ``graphs`` (empty tuples when ``graphs`` is empty).
    """
    # operator.itemgetter(*ids) returns a bare item (not a 1-tuple) for a
    # single id and raises TypeError for no ids, which breaks positional
    # indexing downstream; build the tuples explicitly instead.
    idxs = [g.id for g in graphs]
    cli_dict = tuple(clique_dict[i] for i in idxs)
    edge_dict = tuple(all_edges[i] for i in idxs)
    return cli_dict, edge_dict


def readout(clique_dict, all_edges, graphs):
    """Replace each graph's atom-level features with clique-level features.

    For the graph at position ``index`` in ``graphs``:

    * ``clique_dict[index]`` is an iterable of cliques, each an iterable of
      row indices into ``g.x``. The new ``g.x`` has one row per clique, equal
      to the sum of that clique's rows (an empty clique yields a zero row),
      returned as a float32 CPU tensor.
    * ``all_edges[index]`` is a list of ``[u, v]`` pairs that becomes
      ``g.edge_index`` with shape ``(2, E)``.

    Graphs are modified in place. Only graphs with at least one edge receive
    ``edge_attr``/``num_nodes`` and are returned. Edgeless graphs still have
    ``x``/``edge_index`` overwritten but are dropped from the result.

    Args:
        clique_dict: per-graph clique lists, indexed by position in ``graphs``.
        all_edges: per-graph edge lists, indexed the same way.
        graphs: graph objects with a tensor ``x`` and assignable attributes
            (presumably PyG ``Data`` objects — confirm).

    Returns:
        list of the rewritten graphs that have at least one edge.
    """
    graphx = []
    for index, g in enumerate(graphs):
        data = g.x.data  # detached: the readout is not part of autograd

        rows = []
        for clique in clique_dict[index]:
            idx = torch.as_tensor(list(clique), dtype=torch.long, device=data.device)
            # Indexing + sum(dim=0) also handles an empty clique (zero row),
            # where Python's sum([]) would return the int 0 and crash.
            rows.append(data[idx].sum(dim=0))
        if rows:
            g.x = torch.stack(rows).to(device="cpu", dtype=torch.float32)
        else:
            g.x = torch.empty(0)

        # reshape(-1, 2) keeps the (2, E) layout even when E == 0, avoiding
        # the deprecated ``.T`` on a 1-D tensor; long dtype avoids a float
        # round-trip of the indices.
        edges = torch.as_tensor(all_edges[index], dtype=torch.long)
        g.edge_index = edges.reshape(-1, 2).t().contiguous()

        num_edges = g.edge_index.size(1)
        if num_edges == 0:
            continue  # edgeless graphs are excluded from the result

        # Constant per-edge attributes; the meaning of 6 and 3 is not visible
        # here (presumably categorical bond-feature codes — confirm).
        g.edge_attr = torch.tensor([[6, 3]], dtype=torch.long).repeat(num_edges, 1)
        g.num_nodes = g.x.shape[0]
        graphx.append(g)
    return graphx


def read_book(cli_path, edge_path):
    """Load per-molecule cliques and clique edges from two UTF-8 text files.

    Args:
        cli_path: file with one clique per line, formatted
            ``"<mol> <cli_index> /<clique literal>"``. Consecutive lines with
            the same ``<mol>`` form one group. Grouping starts at molecule id 0,
            so if the first line's id is not 0 an empty leading group is
            produced.
        edge_path: file with one Python literal per line, a sequence of
            ``(u, v)`` pairs.

    Returns:
        ``(cliques, all_edge)``: ``cliques`` is a list of per-molecule clique
        lists, and ``all_edge[k]`` is line ``k`` of ``edge_path`` as a list of
        ``[u, v]`` lists.
    """
    cliques = _read_clique_file(cli_path)
    all_edge = _read_edge_file(edge_path)
    return cliques, all_edge


def _read_edge_file(edge_path):
    """Parse each line of ``edge_path`` into a list of ``[u, v]`` pairs."""
    all_edge = []
    with open(edge_path, 'r', encoding='utf-8') as edges:
        for line in edges:
            # NOTE(review): eval executes arbitrary code; only use on trusted,
            # locally generated files (ast.literal_eval would be safer if the
            # files contain only literals).
            pairs = eval(line.strip('\n'))
            # Convert the (u, v) tuples into two-element lists.
            all_edge.append([[u, v] for u, v in pairs])
    return all_edge


def _read_clique_file(cli_path):
    """Group ``"<mol> <cli_index> /<clique>"`` lines into per-molecule lists."""
    cliques = []
    clique = []
    mol_index = 0
    with open(cli_path, 'r', encoding='utf-8') as file:
        for line in file:
            # maxsplit=1 so a ' /' inside the clique literal is not truncated.
            header, body = line.strip('\n').split(' /', 1)
            mol_str, cli_str = header.split(' ')
            mol, _cli_index = int(mol_str), int(cli_str)
            if mol != mol_index:
                # A change of molecule id closes the current group.
                cliques.append(clique)
                clique = []
                mol_index = mol
            # NOTE(review): eval on file contents — trusted input only.
            clique.append(eval(body))
    cliques.append(clique)
    return cliques
